{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "09ee9f88",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "# Textual Inversion with Determined\n",
    "\n",
    "This notebook generates images from the trained textual inversion models generated with the `detsd.DetSDTextualInversionTrainer` class and saved as Determined checkpoints.  This notebook should be connected to a GPU.\n",
    "\n",
    "### Pre-Launch Setup\n",
    "\n",
    "A [Huggingface User Access Token](https://huggingface.co/docs/hub/security-tokens) is required to download the [Stable Diffusion weights](https://huggingface.co/CompVis/stable-diffusion-v1-4). To use this notebook, please modify the following lines in the `detsd-notebook.yaml` file:\n",
    "```yaml\n",
    "environment:\n",
    "    environment_variables:\n",
    "        - HF_AUTH_TOKEN=YOUR_HF_AUTH_TOKEN_HERE\n",
    "```\n",
    "after which this notebook can be launched by calling the below from the root of the repo directory\n",
    "```bash\n",
    "det -m MASTER_URL_WITH_PORT notebook start --config-file detsd-notebook.yaml --context .\n",
    "```\n",
    "to load the entire repo into the JupyterLab instance, or \n",
    "```bash\n",
    "det -m MASTER_URL_WITH_PORT notebook start --config-file detsd-notebook.yaml -i detsd -i startup-hook.sh -i learned_embeddings_dict_demo.pt -i textual_inversion.ipynb\n",
    "```\n",
    "to only load the files relevant to this notebook.  After either option above, the copy of `textual_inversion.ipynb` on the master can be opened and run."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "70ef7856",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Update the jupyter notebook for better progress-bar rendering (may require a re-load to take effect). If the notebook was launched using the above command, other dependencies were already installed upon agent start-up through the repo's `startup-hook.sh` script."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "243063f8",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "! pip install -qq jupyterlab-widgets==1.1.1 ipywidgets==7.7.2"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c426e42f",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Creating the Pipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e5ad4d0",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Import the `DetSDTextualInversionPipeline` class from `detsd.py` (loaded via the `--context` flag above), which will be used to generate Stable Diffusion images."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8dd88c3a",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "from detsd import DetSDTextualInversionPipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10b3dc72",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Instantiate the pipeline with the default arguments:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6a00b1d",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "detsd_pipeline = DetSDTextualInversionPipeline()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "faf1be4a",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Note: `DetSDTextualInversionPipeline` is initialized with `use_fp16=True` by default which increases inference speed and reduces memory usage, at the cost of somewhat reduced-quality images.  All available args can be viewed by uncommenting and running the cell below"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "865ef6a6",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# ? DetSDTextualInversionPipeline"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "38cb20a2",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Load Determined Checkpoints"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "058eedd9",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "We can now load textual-inversion checkpoints into the model. They are assumed to have been trained with `DetSDTextualInversionTrainer`, also contained in `detsd.py`.  These Determined checkpoints can be specified by their uuid, assuming all such checkpoints exist on the master we are currently logged into."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7fe8adc1",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# Code for logging into the master, if not already logged in.\n",
    "# Not required if notebook was launched as described above.\n",
    "\n",
    "# from determined.experimental import client\n",
    "# client.login(master=MASTER_URL_WITH_PORT, user=USER, password=PASS)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "673aa823",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Fill in the `uuids` list below with the `uuid` `str` values of any Determined checkpoints you wish to incorporate into the model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "912e0a98",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "uuids = []\n",
    "detsd_pipeline.load_from_uuids(uuids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b9079e23-1544-49c1-801d-133c54a04f13",
   "metadata": {},
   "source": [
    "A sample embedding is also included in the repo (with corresponding concept token `det-logo-demo`) and can be loaded in as follows (assuming the notebook was launched with the `-i learned_embeddings_dict_demo.pt` arg):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4f8ac6b2-9987-4ddb-b83b-b630b3bd7b7e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from os.path import exists\n",
    "demo_concept_path = 'learned_embeddings_dict_demo.pt'\n",
    "if exists(demo_concept_path):\n",
    "    detsd_pipeline.load_from_checkpoint_dir(checkpoint_dir='.', learned_embeddings_filename='learned_embeddings_dict_demo.pt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "60ae783f",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## Generate Images\n",
    "\n",
    "Finally, let's generate some art."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "58bbee51",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Grab the first concept which was loaded into the pipeline and store it as `first_concept`.  If no concepts were loaded above, fall back to using `brain logo, sharp lines, connected circles, concept art` as a default value for `first_concept`; vanilla Stable Diffusion is being used in this case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d88578a4",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "all_added_concepts = detsd_pipeline.all_added_concepts\n",
    "if all_added_concepts:\n",
    "    first_concept = all_added_concepts[0]\n",
    "else:\n",
    "    first_concept = 'brain logo, sharp lines, connected circles, concept art'\n",
    "print(f'Using \"{first_concept}\" as first_concept in the below\\n')\n",
    "print(f'All available concepts: {all_added_concepts}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5159f832",
   "metadata": {},
   "source": [
    "Create a directory for saved images and an index for tracking the number of images created."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1e0ce5c",
   "metadata": {},
   "outputs": [],
   "source": [
    "save_dir = 'generated_images'\n",
    "! mkdir {save_dir}\n",
    "num_generated_images = 0"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10a36244",
   "metadata": {},
   "source": [
    "The below code uses creates `batch_size * num_images_per_prompt` total images from the prompt.\n",
    "\n",
    "If you are generating using the demo embedding with `det-logo-demo` as `first_concept`, we recommend setting the guidance scale to a relatively low value, e.g. ~3."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2c85a4c1",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "prompt = f'a watercolor painting on textured paper of a {first_concept} using soft strokes, pastel colors, incredible composition, masterpiece'\n",
    "batch_size = 2\n",
    "num_images_per_prompt = 2\n",
    "\n",
    "generator = torch.Generator(device='cuda').manual_seed(2147483647)\n",
    "output = detsd_pipeline(prompt=[prompt] * batch_size,\n",
    "                        num_images_per_prompt=num_images_per_prompt,\n",
    "                        num_inference_steps=50,\n",
    "                        generator=generator,\n",
    "                        guidance_scale=7.5\n",
    "                       )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "384d0bff",
   "metadata": {},
   "source": [
    "Visualize and save:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0002d5f0",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "for img, nsfw in zip(output.images, output.nsfw_content_detected):\n",
    "    # Skip black images which are made when NSFW is detected.\n",
    "    if not nsfw:\n",
    "        num_generated_images += 1\n",
    "        display(img)\n",
    "        img.save(Path(save_dir).joinpath(f'{num_generated_images}.png'))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a035b9c7",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "Explanation the some arguments above:\n",
    "* `num_inference_steps`: how many steps to run the generation process for. ~50 is typical\n",
    "* `guidance_scale`: tunes how much weight is given to the prompt during generation. 7.5 is the default, with larger numbers leading to stronger adherence to the prompt.\n",
    "* `generator`: pass in a fixed `torch.Generator` instance for reproducible results.\n",
    "\n",
    "`DetSDTextualInversionPipeline`'s `__call__` method accepts the same arguments as its underlying Huggingface `StableDiffusionPipeline` instance; see the [Hugging Face documentation](https://huggingface.co/docs/diffusers/main/en/api/pipelines/stable_diffusion#diffusers.StableDiffusionPipeline.__call__) for information on all available arguments."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "274db8ed5a5b9f1d7e673d5dc8f73328ebbaf45fbf7c788fee56d02d0eb8b109"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
